{
 "cells": [
  {
   "cell_type": "raw",
   "id": "b32297a4",
   "metadata": {},
   "source": [
    "---\n",
    "Copyright 2025 The JAX Authors.\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "You may obtain a copy of the License at\n",
    "\n",
    "    https://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "380b6c4e",
   "metadata": {},
   "source": [
    "# `Ref`: mutable arrays for data plumbing and memory control\n",
    "\n",
    "JAX `Array`s are immutable, representing mathematical values. Immutability can\n",
    "make code easier to reason about, and is useful for optimized compilation,\n",
    "parallelization, rematerialization, and transformations like autodiff.\n",
    "\n",
    "But immutability is constraining too:\n",
    "* **expressiveness** --- plumbing out intermediate data or maintaining state,\n",
    "  e.g. for normalization statistics or metrics, can feel heavyweight;\n",
    "* **performance** --- it's more difficult to reason about performance, like\n",
    "  memory lifetimes and in-place updates.\n",
    "\n",
    "`Ref`s can help! They represent mutable arrays that can be read and written\n",
    "in-place. These array references are compatible with JAX transformations, like\n",
    "`jax.jit` and `jax.grad`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7909a3e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "x_ref = jax.new_ref(jnp.zeros(3))  # new array ref, with initial value [0., 0., 0.]\n",
    "\n",
    "@jax.jit\n",
    "def f():\n",
    "  x_ref[1] += 1.  # indexed add-update\n",
    "\n",
    "print(x_ref)  # Ref([0., 0., 0.])\n",
    "f()\n",
    "f()\n",
    "print(x_ref)  # Ref([0., 2., 0.])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "667af649",
   "metadata": {},
   "source": [
    "The indexing syntax follows NumPy's. For a `Ref` called `x_ref`, we can\n",
    "read its entire value into an `Array` by writing `x_ref[...]`, and write its\n",
    "entire value using `x_ref[...] = A` for some `Array`-valued expression `A`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1824d27",
   "metadata": {},
   "outputs": [],
   "source": [
    "def g(x):\n",
    "  x_ref = jax.new_ref(0.)\n",
    "  x_ref[...] = jnp.sin(x)\n",
    "  return x_ref[...]\n",
    "\n",
    "print(jax.grad(g)(1.0))  # 0.54"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff8dc074",
   "metadata": {},
   "source": [
    "`Ref` is a distinct type from `Array`, and it comes with some important\n",
    "constraints and limitations. In particular, indexed reads and writes are just\n",
    "about the *only* things you can do with a `Ref`. References can't be passed\n",
    "where `Array`s are expected:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2191893",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_ref = jax.new_ref(1.0)\n",
    "try:\n",
    "  jnp.sin(x_ref)  # error! can't do math on refs\n",
    "except Exception as e:\n",
    "  print(e)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ab77be5",
   "metadata": {},
   "source": [
    "To do math, you need to read the ref's value first, like `jnp.sin(x_ref[...])`.\n",
    "\n",
    "So what _can_ you do with `Ref`? Read on for the details, and some useful\n",
    "recipes.\n",
    "\n",
    "### API\n",
    "\n",
    "If you've ever used\n",
    "[Pallas](https://docs.jax.dev/en/latest/pallas/quickstart.html), then `Ref`\n",
    "should look familiar. A big difference is that you can create new `Ref`s\n",
    "yourself directly using `jax.new_ref`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cc852f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax import Array, Ref\n",
    "\n",
    "def new_ref(init_val: Array) -> Ref:\n",
    "  \"\"\"Introduce a new reference with given initial value.\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4565356",
   "metadata": {},
   "source": [
    "`jax.freeze` is its antithesis, invalidating the given ref (so that accessing it\n",
    "afterwards is an error) and producing its final value:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "049048ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "def freeze(ref: Ref) -> Array:\n",
    "  \"\"\"Invalidate given reference and produce its final value.\"\"\""
   ]
  },
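  {
   "cell_type": "markdown",
   "id": "d1a7f3b0",
   "metadata": {},
   "source": [
    "As a minimal sketch of that lifecycle, the cell below creates a ref, updates it\n",
    "in place, and then freezes it. The name `acc_ref` and the values are purely\n",
    "illustrative; after the call to `jax.freeze`, `acc_ref` should no longer be read\n",
    "or written:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c2e9a41",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "acc_ref = jax.new_ref(jnp.zeros(3))  # illustrative name, initial value [0., 0., 0.]\n",
    "acc_ref[0] = 1.0     # indexed write\n",
    "acc_ref[...] += 2.0  # whole-array add-update\n",
    "\n",
    "final = jax.freeze(acc_ref)  # acc_ref is invalid from here on\n",
    "print(final)  # [3. 2. 2.]"
   ]
  },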
  {
   "cell_type": "markdown",
   "id": "62dd629d",
   "metadata": {},
   "source": [
    "In between creating and destroying them, you can perform indexed reads and\n",
    "writes on refs. You can read and write using the functions `jax.ref.get` and\n",
    "`jax.ref.swap`, but usually you'd just use NumPy-style array indexing syntax:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61b34483",
   "metadata": {},
   "outputs": [],
   "source": [
    "import types\n",
    "Index = int | slice | Array | types.EllipsisType\n",
    "Indexer = Index | tuple[Index, ...]\n",
    "\n",
    "def get(ref: Ref, idx: Indexer) -> Array:\n",
    "  \"\"\"Returns `ref[idx]` for NumPy-style indexer `idx`.\"\"\"\n",
    "\n",
    "def swap(ref: Ref, idx: Indexer, val: Array) -> Array:\n",
    "  \"\"\"Performs `oldval, ref[idx] = ref[idx], val` and returns `oldval`.\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0ae59b8",
   "metadata": {},
   "source": [
    "Here, `Indexer` can be any NumPy indexing expression:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5f080d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_ref = jax.new_ref(jnp.arange(12.).reshape(3, 4))\n",
    "\n",
    "# int indexing\n",
    "row = x_ref[0]\n",
    "x_ref[1] = row\n",
    "\n",
    "# tuple indexing\n",
    "val = x_ref[1, 2]\n",
    "x_ref[2, 3] = val\n",
    "\n",
    "# slice indexing\n",
    "col = x_ref[:, 1]\n",
    "x_ref[0, :3] = col\n",
    "\n",
    "# advanced int array indexing\n",
    "vals = x_ref[jnp.array([0, 0, 1]), jnp.array([1, 2, 3])]\n",
    "x_ref[jnp.array([1, 2, 1]), jnp.array([0, 0, 1])] = vals"
   ]
  },
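  {
   "cell_type": "markdown",
   "id": "7e3b8f62",
   "metadata": {},
   "source": [
    "The same reads and writes can also be spelled with the function forms. This is\n",
    "a small sketch, assuming `jax.ref.get` and `jax.ref.swap` accept the arguments\n",
    "shown in the signatures above; `buf_ref` and its values are made up for\n",
    "illustration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a94c1d05",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "buf_ref = jax.new_ref(jnp.arange(4.))  # hypothetical example ref: [0., 1., 2., 3.]\n",
    "\n",
    "first = jax.ref.get(buf_ref, 0)  # same as buf_ref[0]\n",
    "\n",
    "# swap writes the new values and hands back what was stored there before\n",
    "old = jax.ref.swap(buf_ref, slice(1, 3), jnp.array([10., 20.]))\n",
    "\n",
    "print(first)         # 0.0\n",
    "print(old)           # [1. 2.]\n",
    "print(buf_ref[...])  # [ 0. 10. 20.  3.]"
   ]
  },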
  {
   "cell_type": "markdown",
   "id": "bd3edc22",
   "metadata": {},
   "source": [
    "As with `Array`s, indexing mostly follows NumPy behavior, except for\n",
    "out-of-bounds indexing which [behaves in the usual way for JAX\n",
    "`Array`s](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#out-of-bounds-indexing).\n",
    "\n",
    "### Pure and impure functions\n",
    "\n",
    "A function that takes a ref as an argument (either explicitly or by lexical\n",
    "closure) is considered _impure_. For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2841ccb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# takes ref as an argument => impure\n",
    "@jax.jit\n",
    "def impure1(x_ref, y_ref):\n",
    "  x_ref[...] = y_ref[...]\n",
    "\n",
    "# closes over ref => impure\n",
    "y_ref = jax.new_ref(0)\n",
    "\n",
    "@jax.jit\n",
    "def impure2(x):\n",
    "  y_ref[...] = x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6b946f6",
   "metadata": {},
   "source": [
    "If a function only uses refs internally, it is still considered _pure_. Purity\n",
    "is in the eye of the caller. For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf8fb062",
   "metadata": {},
   "outputs": [],
   "source": [
    "# internal refs => still pure\n",
    "@jax.jit\n",
    "def pure1(x):\n",
    "  ref = jax.new_ref(x)\n",
    "  ref[...] = ref[...] + ref[...]\n",
    "  return ref[...]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09ed144d",
   "metadata": {},
   "source": [
    "Pure functions, even those that use refs internally, are familiar: for example,\n",
    "they work with transformations like `jax.grad`, `jax.vmap`, `jax.shard_map`, and\n",
    "others in the usual way.\n",
    "\n",
    "Calls to impure functions are sequenced in Python program order, so their reads\n",
    "and writes to refs take effect in the order the calls are made.\n",
    "\n",
    "### Restrictions\n",
    "\n",
    "`Ref`s are second-class, in the sense that there are restrictions on their\n",
    "use:\n",
    "\n",
    "* **Can't return refs** from `jit`\\-decorated functions or the bodies of\n",
    "  higher-order primitives like `jax.lax.scan`, `jax.lax.while_loop`, or\n",
    "  `jax.lax.cond`\n",
    "* **Can't pass a ref as an argument more than once** to `jit`\\-decorated\n",
    "  functions or higher-order primitives\n",
    "* **Can only `freeze` in creation scope**\n",
    "* **No higher-order refs** (refs-to-refs)\n",
    "\n",
    "For example, these are errors:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61a4e501",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_ref = jax.new_ref(0.)\n",
    "\n",
    "# can't return refs\n",
    "@jax.jit\n",
    "def err1(x_ref):\n",
    "  x_ref[...] = 5.\n",
    "  return x_ref  # error!\n",
    "try:\n",
    "  err1(x_ref)\n",
    "except Exception as e:\n",
    "  print(e)\n",
    "\n",
    "# can't pass a ref as an argument more than once\n",
    "@jax.jit\n",
    "def err2(x_ref, y_ref):\n",
    "  ...\n",
    "try:\n",
    "  err2(x_ref, x_ref)  # error!\n",
    "except Exception as e:\n",
    "  print(e)\n",
    "\n",
    "# can't pass and close over the same ref\n",
    "@jax.jit\n",
    "def err3(y_ref):\n",
    "  y_ref[...] = x_ref[...]\n",
    "try:\n",
    "  err3(x_ref)  # error!\n",
    "except Exception as e:\n",
    "  print(e)\n",
    "\n",
    "# can only freeze in creation scope\n",
    "@jax.jit\n",
    "def err4(x_ref):\n",
    "  jax.freeze(x_ref)\n",
    "try:\n",
    "  err4(x_ref)  # error!\n",
    "except Exception as e:\n",
    "  print(e)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc360213",
   "metadata": {},
   "source": [
    "These restrictions exist to rule out aliasing, where two refs might refer to the\n",
    "same mutable memory, making programs harder to reason about and transform.\n",
    "Weaker restrictions would also suffice, so some of these restrictions may be\n",
    "lifted as we improve JAX's ability to verify that no aliasing is present.\n",
    "\n",
    "There are also restrictions stemming from undefined semantics, e.g. in the\n",
    "presence of parallelism or rematerialization:\n",
    "\n",
    "* **Can't `vmap` or `shard_map` a function that closes over refs**\n",
    "* **Can't apply `jax.remat`/`jax.checkpoint` to an impure function**\n",
    "\n",
    "For example, here are ways you can and can't use `vmap` with impure functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5701f96e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# vmap over ref args is okay\n",
    "def dist(x, y, out_ref):\n",
    "  assert x.ndim == y.ndim == 1\n",
    "  assert out_ref.ndim == 0\n",
    "  out_ref[...] = jnp.sum((x - y) ** 2)\n",
    "\n",
    "vecs = jnp.arange(12.).reshape(3, 4)\n",
    "out_ref = jax.new_ref(jnp.zeros((3, 3)))\n",
    "jax.vmap(jax.vmap(dist, (0, None, 0)), (None, 0, 0))(vecs, vecs, out_ref)  # ok!\n",
    "print(out_ref)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d94d08be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# vmap with a closed-over ref is not\n",
    "x_ref = jax.new_ref(0.)\n",
    "\n",
    "def err5(x):\n",
    "  x_ref[...] = x\n",
    "\n",
    "try:\n",
    "  jax.vmap(err5)(jnp.arange(3.))  # error!\n",
    "except Exception as e:\n",
    "  print(e)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33e635e5",
   "metadata": {},
   "source": [
    "The latter is an error because each batch element writes a different value, so\n",
    "it's not clear which one `x_ref` should hold after we run `jax.vmap(err5)`.\n",
    "\n",
    "### `Ref`s and automatic differentiation\n",
    "\n",
    "Autodiff can be applied to pure functions as before, even if they use array refs\n",
    "internally. For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b5d32e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def pure2(x):\n",
    "  ref = jax.new_ref(x)\n",
    "  ref[...] = ref[...] + ref[...]\n",
    "  return ref[...]\n",
    "\n",
    "print(jax.grad(pure2)(3.0))  # 2.0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "801c3b60",
   "metadata": {},
   "source": [
    "Autodiff can also be applied to functions that take array refs as arguments, if\n",
    "those arguments are only used for plumbing and not involved in differentiation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6cd5576",
   "metadata": {},
   "outputs": [],
   "source": [
    "# error\n",
    "def err6(x, some_plumbing_ref):\n",
    "  y = x + x\n",
    "  some_plumbing_ref[...] += y\n",
    "  return y\n",
    "\n",
    "# fine\n",
    "def foo(x, some_plumbing_ref):\n",
    "  y = x + x\n",
    "  some_plumbing_ref[...] += jax.lax.stop_gradient(y)\n",
    "  return y"
   ]
  },
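  {
   "cell_type": "markdown",
   "id": "2b6f0e73",
   "metadata": {},
   "source": [
    "To see the `stop_gradient` version in action, here is a hedged sketch that\n",
    "differentiates a function shaped like `foo` while it accumulates into a plumbing\n",
    "ref. The names `double_and_log` and `log_ref` are illustrative and not part of\n",
    "JAX:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8d5c4a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "def double_and_log(x, log_ref):\n",
    "  y = x + x\n",
    "  # stop_gradient keeps the logged value out of differentiation\n",
    "  log_ref[...] += jax.lax.stop_gradient(y)\n",
    "  return y\n",
    "\n",
    "log_ref = jax.new_ref(0.)  # illustrative plumbing ref\n",
    "dy_dx = jax.grad(double_and_log)(3.0, log_ref)\n",
    "\n",
    "print(dy_dx)         # 2.0\n",
    "print(log_ref[...])  # the value stashed during the forward pass"
   ]
  },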
  {
   "cell_type": "markdown",
   "id": "86622dd6",
   "metadata": {},
   "source": [
    "You can combine plumbing refs with `custom_vjp` to plumb data out of the\n",
    "backward pass of a differentiated function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1f17fbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First, define the helper `stash_grads`:\n",
    "\n",
    "@jax.custom_vjp\n",
    "def stash_grads(grads_ref, x):\n",
    "  return x\n",
    "\n",
    "def stash_grads_fwd(grads_ref, x):\n",
    "  return x, grads_ref\n",
    "\n",
    "def stash_grads_bwd(grads_ref, g):\n",
    "  grads_ref[...] = g\n",
    "  return None, g\n",
    "\n",
    "stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0c5842e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now, use `stash_grads` to stash intermediate gradients:\n",
    "\n",
    "def f(x, grads_ref):\n",
    "  x = jnp.sin(x)\n",
    "  x = stash_grads(grads_ref, x)\n",
    "  return x\n",
    "\n",
    "grads_ref = jax.new_ref(0.)\n",
    "jax.grad(f)(1., grads_ref)  # differentiate so the bwd rule runs and writes grads_ref\n",
    "print(grads_ref)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d8518e7",
   "metadata": {},
   "source": [
    "Notice `stash_grads_fwd` is returning a `Ref` here. That's a special\n",
    "allowance for `custom_vjp` fwd rules: it's really syntax for indicating which\n",
    "ref arguments should be shared by both the fwd and bwd rules. So any refs\n",
    "returned by a fwd rule must be arguments to that fwd rule.\n",
    "\n",
    "### `Ref`s and performance\n",
    "\n",
    "At the top level, when calling `jit`\\-decorated functions, `Ref`s obviate\n",
    "the need for donation, since they are effectively always donated:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64f3655c",
   "metadata": {},
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def sin_inplace(x_ref):\n",
    "  x_ref[...] = jnp.sin(x_ref[...])\n",
    "\n",
    "x_ref = jax.new_ref(jnp.arange(3.))\n",
    "print(x_ref.unsafe_buffer_pointer(), x_ref)\n",
    "sin_inplace(x_ref)\n",
    "print(x_ref.unsafe_buffer_pointer(), x_ref)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a758adac",
   "metadata": {},
   "source": [
    "Here `sin_inplace` operates in-place, updating the buffer backing `x_ref` so\n",
    "that its address stays the same.\n",
    "\n",
    "Under a `jit`, you should expect array references to point to fixed buffer\n",
    "addresses, and for indexed updates to be performed in-place.\n",
    "\n",
    "**Temporary caveat:** dispatch from Python to impure `jit`\\-compiled functions\n",
    "that take `Ref` inputs is currently slower than dispatch to pure\n",
    "`jit`\\-compiled functions, since it takes a less optimized path.\n",
    "\n",
    "### `foreach`, a new way to write `scan`\n",
    "\n",
    "As you may know, `jax.lax.scan` is a loop construct with a built-in fixed access\n",
    "pattern for scanned-over inputs and outputs. The access pattern is built in for\n",
    "autodiff reasons: if we were instead to slice into immutable inputs directly,\n",
    "reverse-mode autodiff would end up creating one-hot gradients and summing them\n",
    "up, which can be asymptotically inefficient. See [Sec 5.3.3 of the Dex\n",
    "paper](https://arxiv.org/pdf/2104.05372).\n",
    "\n",
    "But reading slices of `Ref`s doesn't have this efficiency problem: when we\n",
    "apply reverse-mode autodiff, we always generate in-place accumulation\n",
    "operations. As a result, we no longer need to be constrained by `scan`'s fixed\n",
    "access pattern. We can write more flexible loops, e.g. with non-sequential\n",
    "access.\n",
    "\n",
    "Moreover, having mutation available allows for some syntax tricks, like in this\n",
    "recipe for a `foreach` decorator:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11753b6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax.lax import scan\n",
    "\n",
    "def foreach(*args):\n",
    "  def decorator(body):\n",
    "    return scan(lambda _, elts: (None, body(*elts)), None, args)[1]\n",
    "  return decorator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ddc7abe",
   "metadata": {},
   "outputs": [],
   "source": [
    "r = jax.new_ref(0)\n",
    "xs = jnp.arange(10)\n",
    "\n",
    "@foreach(xs)\n",
    "def ys(x):\n",
    "  r[...] += x\n",
    "  return x * 2\n",
    "\n",
    "print(r)   # Ref(45, dtype=int32)\n",
    "print(ys)  # [ 0  2  4  6  8 10 12 14 16 18]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "570970cd",
   "metadata": {},
   "source": [
    "Here, the loop runs immediately, updating `r` in-place and binding `ys` to be\n",
    "the mapped result."
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "ipynb,md:myst,py",
   "main_language": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
